package com.diven.spark.ml.learn.feature

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel
import com.diven.spark.ml.learn.core.BaseTest
import com.diven.spark.ml.learn.core.BaseSpark

object BucketedRandomProjectionLSHTest extends BaseTest {

    /** Factory: returns a fresh test instance typed as the common [[BaseSpark]] interface. */
    def apply(): BaseSpark = {
        new BucketedRandomProjectionLSHTest()
    }

}

class BucketedRandomProjectionLSHTest extends BaseSpark {

    /**
     * Demonstrates BucketedRandomProjectionLSH (locality-sensitive hashing for
     * Euclidean distance): assembles feature columns into a vector, fits an LSH
     * model on one random split, then runs an approximate similarity join on the
     * other two splits to surface record pairs within distance threshold 2.
     *
     * @param dataFrame input data; must contain columns "id", "weight", "height", "age"
     */
    override def execute(dataFrame: DataFrame): Unit = {
        // Feature column names to be combined into a single vector column.
        val features = Array("weight", "height", "age")
        // Assemble the feature columns into "vector_features", then split into
        // three random subsets: one to fit the model, two sides for the join.
        val Array(trainData, joinLeft, joinRight) = new VectorAssembler()
            .setInputCols(features)
            .setOutputCol("vector_features")
            .transform(dataFrame.select("id", features: _*))
            .randomSplit(Array(0.4, 0.3, 0.3))
        // Fit the LSH model on the training split.
        val model: BucketedRandomProjectionLSHModel = new BucketedRandomProjectionLSH()
            .setInputCol("vector_features")   // column to hash
            .setOutputCol("bkt_lsh")          // output column holding the hash values
            .setBucketLength(10d)             // bucket length; larger buckets lower the false negative rate
            .setNumHashTables(5)              // more tables lower the false negative rate; fewer improve runtime
            .setSeed(100L)                    // fixed seed for reproducible hash functions
            .fit(trainData)
        // Inspect the hashed training data.
        val transformed = model.transform(trainData)
        transformed.show(10, 100)
        transformed.printSchema()
        // Approximate similarity join: pairs across the two splits whose
        // Euclidean distance is below 2, with the distance in "distCol".
        val recommend = model.approxSimilarityJoin(joinLeft, joinRight, 2, "distCol")
            .select(
                col("datasetA").getField("id").as("id"),
                col("datasetB").getField("id").as("recommend_id"),
                col("datasetA").getField("age").as("age"),
                col("datasetB").getField("age").as("recommend_age"),
                col("datasetA").getField("weight").as("weight"),
                col("datasetB").getField("weight").as("recommend_weight"),
                col("datasetA").getField("height").as("height"),
                col("datasetB").getField("height").as("recommend_height"),
                col("distCol")
            )
        // Show the most similar pairs first for each id.
        recommend.orderBy("id", "distCol").show(100, 1000)
    }

}